import torch

class Config:
    """Central configuration for the joint domain/intent/slot-filling model:
    resource paths, label<->id mappings (built eagerly from the label files
    at class-definition time), and training hyper-parameters."""

    # ---- resource paths ----
    # Pretrained model location. Using "hfl/chinese-roberta-wwm-ext" downloads
    # automatically from HuggingFace; if the download fails, fetch it manually
    # and point this path at the local folder.
    pretrained_model_path = "/home/xinghai/pretrained-models/chinese-roberta-wwm-ext"  # chinese-roberta-wwm-ext model

    train_file_path = "./data/train_process.json"   # training set
    test_file_path = "./data/test_process.json"     # test set
    dev_file_path = ""                              # no dev set yet
    domain_label_file_path = "./data/domains.txt"   # domain labels
    intent_label_file_path = "./data/intents.txt"   # intent labels
    slot_label_file_path = "./data/slot_label.txt"  # slot labels
    model_save_dir = "./checkpoints/"               # directory for saved checkpoints
    model_load_path = "./checkpoints/model.1.pt"    # checkpoint file to load

    # ---- label/id mappings ----
    def _load_label_maps(path):
        """Read a newline-separated label file and return ({label: id}, {id: label}).

        NOTE(review): ``split('\\n')`` keeps a trailing empty label if the file
        ends with a newline. Preserved deliberately so label ids stay aligned
        with any already-trained checkpoints -- confirm the label files carry
        no trailing newline.
        """
        label2id, id2label = {}, {}
        # encoding pinned to utf-8 so Chinese labels load correctly on
        # platforms whose default encoding differs (e.g. Windows cp936)
        with open(path, 'r', encoding='utf-8') as fp:
            for i, label in enumerate(fp.read().split('\n')):
                label2id[label] = i
                id2label[i] = label
        return label2id, id2label

    # domain / intent / slot label <-> id mappings
    domainlabel2id, id2domainlabel = _load_label_maps(domain_label_file_path)
    intentlabel2id, id2intentlabel = _load_label_maps(intent_label_file_path)
    slotlabel2id, id2slotlabel = _load_label_maps(slot_label_file_path)
    del _load_label_maps  # helper only needed while the class body executes

    # ---- control flags ----
    do_train = True
    do_eval = True
    do_save = True
    do_predict = True
    load_model = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # ---- training hyper-parameters ----
    num_domain_labels = len(domainlabel2id)  # number of domain labels
    num_intent_labels = len(intentlabel2id)  # number of intent labels
    num_slot_labels = len(slotlabel2id)      # number of slot labels
    hidden_size = 768                        # RoBERTa-base hidden size
    lr = 2e-5                                # learning rate
    epoch = 20
    dropout = 0.1
    max_len = 32                             # max input sequence length
    batchsize = 64
